﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data.Linq.Mapping;
using System.Data.SqlClient;
using System.Reflection;

namespace GoodStuff.Data.Linq
{
    /// <summary>
    /// Exposes a sequence of LINQ to SQL entities through the
    /// <c>HeavyOperationSqlBulkCopyReader</c> data-reader surface, intended as the row source
    /// of a <c>SqlBulkCopy</c> operation.
    /// </summary>
    /// <remarks>
    /// The table name comes from the <see cref="TableAttribute"/> on <typeparamref name="TEntity"/>.
    /// Columns are the public properties carrying a <see cref="ColumnAttribute"/>, excluding
    /// columns whose values the database produces itself (identity, IsDbGenerated, IsVersion).
    /// Column ordinals are assigned in property discovery order, starting at 0.
    /// </remarks>
    /// <example>http://blogs.microsoft.co.il/blogs/aviwortzel/archive/2008/05/06/implementing-sqlbulkcopy-in-linq-to-sql.aspx</example>
    /// <typeparam name="TEntity">Entity type mapped with <see cref="TableAttribute"/> / <see cref="ColumnAttribute"/>.</typeparam>
    internal class HeavyOperationBulkCopyReader<TEntity> : HeavyOperationSqlBulkCopyReader
        where TEntity : class
    {
        /// <summary>
        /// Builds the column mapping for <typeparamref name="TEntity"/> and captures an
        /// enumerator over <paramref name="entities"/>.
        /// </summary>
        /// <param name="entities">Rows to expose; advanced one at a time by <see cref="Read"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="InvalidOperationException"><typeparamref name="TEntity"/> has no <see cref="TableAttribute"/>.</exception>
        public HeavyOperationBulkCopyReader(IEnumerable<TEntity> entities)
        {
            if (entities == null)
            {
                throw new ArgumentNullException("entities");
            }

            // Use the static element type instead of reflecting over the runtime collection type:
            // GetInterface("IEnumerable`1") throws AmbiguousMatchException for collections that
            // implement IEnumerable<T> more than once, and may resolve to a derived type that does
            // not itself carry [Table] (the lookup below does not inherit attributes).
            var entityType = typeof(TEntity);

            var tableAttribute = entityType
                .GetCustomAttributes(typeof(TableAttribute), false)
                .OfType<TableAttribute>()
                .FirstOrDefault();
            if (tableAttribute == null)
            {
                throw new InvalidOperationException(string.Format(
                    "Type '{0}' is not mapped to a table: it has no [Table] attribute.", entityType.FullName));
            }

            // [Table] without an explicit Name maps to the class name, mirroring the
            // property-name fallback used for columns below.
            _TableName = tableAttribute.Name ?? entityType.Name;

            _ColumnMappingList = new List<ColumnMapping>();
            foreach (var property in entityType.GetProperties())
            {
                // Copy into a loop-local so each getter closure binds its own property
                // (foreach variables are shared across iterations before C# 5).
                var source = property;
                var columns = source
                    .GetCustomAttributes(typeof(ColumnAttribute), false)
                    .OfType<ColumnAttribute>();

                foreach (var column in columns)
                {
                    // Only columns the client supplies get an ordinal; server-generated ones are skipped.
                    if (IsGeneratedByServer(column))
                    {
                        continue;
                    }

                    _ColumnMappingList.Add(new ColumnMapping()
                    {
                        // Ordinal == position in the list; GetValue relies on this to index directly.
                        ColumnIndex = _ColumnMappingList.Count,
                        ColumnName = column.Name ?? source.Name,
                        ColumnGetter = row => source.GetValue(row, null)
                    });
                }
            }

            // Acquire the enumerator last so a mapping failure above does not leave one behind.
            _Enumerator = entities.GetEnumerator();
        }

        /// <summary>
        /// Returns true when the database produces the column's value: an IDENTITY column
        /// (detected from <see cref="DataAttribute.DbType"/>), a column flagged
        /// <see cref="ColumnAttribute.IsDbGenerated"/>, or a row-version column.
        /// </summary>
        private static bool IsGeneratedByServer(ColumnAttribute column)
        {
            // DbType is optional on [Column] and is null when omitted, so guard before searching it.
            // SQL keywords are case-insensitive, hence the ignore-case match.
            bool isIdentity = column.DbType != null &&
                column.DbType.IndexOf("IDENTITY", StringComparison.OrdinalIgnoreCase) >= 0;

            return isIdentity || column.IsDbGenerated || column.IsVersion;
        }

        private readonly IEnumerator<TEntity> _Enumerator;
        private readonly IList<ColumnMapping> _ColumnMappingList;
        private readonly string _TableName;

        /// <summary>
        /// Table name from <see cref="TableAttribute.Name"/>, or the entity class name when none is set.
        /// </summary>
        public string TableName
        {
            get { return _TableName; }
        }

        /// <summary>
        /// Mapped column names, in ordinal order.
        /// </summary>
        public IEnumerable<string> Columns
        {
            get
            {
                return _ColumnMappingList.Select(column => column.ColumnName);
            }
        }

        /// <summary>
        /// A new dictionary from column name to ordinal. Throws <see cref="ArgumentException"/>
        /// if two properties map to the same column name.
        /// </summary>
        public IDictionary<string, int> ColumnMappingList
        {
            get
            {
                return _ColumnMappingList.ToDictionary(column => column.ColumnName, column => column.ColumnIndex);
            }
        }

        /// <summary>
        /// Advances to the next entity; returns false once the sequence is exhausted.
        /// </summary>
        public override bool Read()
        {
            return _Enumerator.MoveNext();
        }

        /// <summary>
        /// Returns the value of column <paramref name="i"/> for the current entity.
        /// </summary>
        /// <exception cref="IndexOutOfRangeException"><paramref name="i"/> is outside 0..FieldCount-1.</exception>
        public override object GetValue(int i)
        {
            // IDataRecord.GetValue documents IndexOutOfRangeException for a bad ordinal.
            if (i < 0 || i >= _ColumnMappingList.Count)
            {
                throw new IndexOutOfRangeException(string.Format(
                    "Column ordinal {0} is outside the range 0 to {1}.", i, _ColumnMappingList.Count - 1));
            }

            // Ordinals are list positions (see constructor), so index directly rather than
            // scanning the mapping list for every cell.
            return _ColumnMappingList[i].ColumnGetter(_Enumerator.Current);
        }

        /// <summary>
        /// Number of mapped (client-supplied) columns.
        /// </summary>
        public override int FieldCount
        {
            get
            {
                return _ColumnMappingList.Count;
            }
        }

        /// <summary>
        /// Returns the ordinal of the column named <paramref name="name"/> (ordinal, case-sensitive match).
        /// </summary>
        /// <exception cref="IndexOutOfRangeException">No mapped column has that name.</exception>
        public override int GetOrdinal(string name)
        {
            for (int ordinal = 0; ordinal < _ColumnMappingList.Count; ordinal++)
            {
                if (_ColumnMappingList[ordinal].ColumnName == name)
                {
                    return ordinal;
                }
            }

            // IDataRecord.GetOrdinal documents IndexOutOfRangeException for an unknown name.
            throw new IndexOutOfRangeException(string.Format("No column named '{0}' is mapped.", name));
        }

        /// <summary>
        /// One mapped column: its ordinal, its database column name, and a getter that
        /// reads the backing property from an entity instance.
        /// </summary>
        private sealed class ColumnMapping
        {
            public int ColumnIndex { get; set; }
            public string ColumnName { get; set; }
            public Func<object, object> ColumnGetter { get; set; }
        }
    }
}
